{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch \n",
    "import torch.nn as nn \n",
    "import torch.nn.functional as F \n",
    "\n",
    "class Attention(nn.Module): # single head for now, we'll add multi-head later!\n",
    "    def __init__(self, D=256): \n",
    "        super().__init__()\n",
    "        self.D = D\n",
    "        # divide by sqrt(D) to keep variance roughly constant, otherwise logits get too big\n",
    "        self.scale = torch.sqrt(torch.tensor(D, dtype=torch.float32))\n",
    "        # these are just linear projections that map input -> query/key/value vectors\n",
    "        self.wq = nn.Linear(D, D) # query projection\n",
    "        self.wk = nn.Linear(D, D) # key projection  \n",
    "        self.wv = nn.Linear(D, D) # value projection\n",
    "        self.wo = nn.Linear(D, D) # final output projection\n",
    "\n",
    "        # critical to understand that weights themselves are independent of seqlen, S\n",
    "        # so we can adaptively handle varying length sequences at inference time\n",
    "\n",
    "    def forward(self, x): # x is [B, S, D] (batch, sequence length, hidden dim)\n",
    "        # project input into Q,K,V vectors - each is [B, S, D]\n",
    "        Q, K, V = self.wq(x), self.wk(x), self.wv(x)\n",
    "        \n",
    "        # compute attention scores between each position - [B, S, D] @ [B, D, S] -> [B, S, S]\n",
    "        # scale to prevent softmax saturation which would kill gradients\n",
    "        A_logits = (Q @ K.transpose(1, 2))/self.scale\n",
    "        # this is the heart of attention, keys and values \"talking to each other\"\n",
    "        # and why people call attention a \"soft lookup\" since it's like matching a new key \n",
    "        # to all existing keys based on how similar that key is, rather than with eg. relational DBs\n",
    "        # where you need an exact match to return the value \n",
    "        \n",
    "        # convert scores to probabilities with softmax - each query attends to all keys\n",
    "        A = F.softmax(A_logits, dim=-1) # [B, S, S]\n",
    "        \n",
    "        # weighted sum of values based on attention probs\n",
    "        # [B, S, S] @ [B, S, D] -> [B, S, D], then project back to output space\n",
    "        return self.wo(A@V)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test passed! Output shape is correct.\n"
     ]
    }
   ],
   "source": [
    "# Test the attention module\n",
    "if __name__ == \"__main__\":\n",
    "    # torch and Attention come from the cell above, so nothing is re-imported here\n",
    "\n",
    "    # dummy inputs\n",
    "    batch_size, seq_len, hidden_dim = 2, 4, 256\n",
    "    x = torch.randn(batch_size, seq_len, hidden_dim)\n",
    "\n",
    "    attn = Attention(D=hidden_dim)\n",
    "\n",
    "    # run the forward pass (no gradients needed for a shape check)\n",
    "    with torch.no_grad():\n",
    "        out = attn(x)\n",
    "\n",
    "    # check shapes\n",
    "    assert out.shape == (batch_size, seq_len, hidden_dim), f\"Expected shape {(batch_size, seq_len, hidden_dim)} but got {out.shape}\"\n",
    "    # softmax over finite logits should never produce NaN/Inf\n",
    "    assert torch.isfinite(out).all(), \"Output contains NaN or Inf\"\n",
    "\n",
    "    # the same weights must also handle a different sequence length\n",
    "    longer_len = 2 * seq_len + 1\n",
    "    with torch.no_grad():\n",
    "        out_longer = attn(torch.randn(batch_size, longer_len, hidden_dim))\n",
    "    assert out_longer.shape == (batch_size, longer_len, hidden_dim), f\"Expected shape {(batch_size, longer_len, hidden_dim)} but got {out_longer.shape}\"\n",
    "\n",
    "    print(\"Test passed! Output shape is correct.\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "envi",
   "language": "python",
   "name": "lingua_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
